# Copyright (c) 2019-2021 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import ext_ops

###########################################################################################


class Bmm1Function(torch.autograd.Function):
    """Autograd binding for the ``ext_ops`` FastBmm1 kernels with two operands.

    The forward pass launches ``ext_ops.FastBmm1Fprop`` and the backward pass
    launches ``ext_ops.FastBmm1Dgrad2`` / ``ext_ops.FastBmm1Dgrad1``.  The
    boolean passed to the kernels right after ``scale`` is always ``False``
    here.
    NOTE(review): presumably this is the variable-sequence-length attention
    score product (Q x K^T) and the ``False`` flag selects the non-strided
    (separate operand) layout -- confirm against the ``ext_ops`` sources.
    """

    @staticmethod
    def forward(ctx, batch1, batch2, seqlen, batch, maxseqlen, heads, embed,
                scale, stream, sync):
        """Run the forward kernel and stash everything backward needs.

        Args:
            ctx: Autograd context.
            batch1: First operand tensor; flattened before the kernel call.
            batch2: Second operand tensor; flattened before the kernel call.
            seqlen: Integer tensor of per-sequence lengths.
            batch: Number of sequences; only the first ``batch`` entries of
                ``seqlen`` contribute to the output size.
            maxseqlen: Stored on ``ctx``; not otherwise used here.
            heads: Head count; multiplies the output size.
            embed: Forwarded to the kernel (assumed per-head width -- TODO
                confirm).
            scale, stream, sync: Forwarded unchanged to the kernel.

        Returns:
            A 1-D float16 tensor with ``heads * sum(seqlen[i] ** 2)`` elements
            (``i`` ranging over the first ``batch`` sequences).
        """
        ctx.save_for_backward(batch1, batch2, seqlen)
        ctx.batch = batch
        ctx.maxseqlen = maxseqlen
        ctx.heads = heads
        ctx.embed = embed
        ctx.scale = scale
        ctx.sync = sync
        ctx.stream = stream
        ctx.ntokens = seqlen.sum().item()

        # One vectorised reduction instead of indexing the tensor element by
        # element; int64 keeps the squared sum from wrapping in a narrow dtype
        # and ``.item()`` yields a plain Python int for the allocation below.
        lengths = seqlen[:batch].to(torch.int64)
        ntokens2 = (lengths * lengths).sum().item()

        # Allocate next to the inputs rather than on whatever the current
        # CUDA device happens to be.
        output = torch.empty(ntokens2 * heads,
                             device=batch1.device,
                             dtype=torch.float16)
        # Note the operand order: batch2 is handed to the kernel first.
        ext_ops.FastBmm1Fprop(batch2.flatten().contiguous(),
                              batch1.flatten().contiguous(), output, batch,
                              seqlen, heads, embed, scale, False, stream,
                              sync)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``batch1`` and ``batch2``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient with respect to the forward output.

        Returns:
            A 10-tuple matching the forward inputs: the gradients for
            ``batch1`` and ``batch2`` (each ``(ntokens, heads * embed)``
            float16) followed by ``None`` for the eight non-differentiable
            arguments.
        """
        batch1, batch2, seqlen = ctx.saved_tensors
        batch = ctx.batch
        heads = ctx.heads
        embed = ctx.embed

        grad_shape = (ctx.ntokens, heads * embed)
        grad_batch1 = torch.empty(grad_shape,
                                  device=batch1.device,
                                  dtype=torch.float16)
        grad_batch2 = torch.empty(grad_shape,
                                  device=batch2.device,
                                  dtype=torch.float16)

        # Flatten the incoming gradient once; both kernels read the same data.
        flat_grad_output = grad_output.flatten().contiguous()

        # ``view(-1)`` guarantees the kernel receives an alias of the freshly
        # allocated buffer (never a copy), so its writes land in the tensor
        # returned below.
        ext_ops.FastBmm1Dgrad2(batch2.flatten().contiguous(),
                               flat_grad_output, grad_batch1.view(-1), batch,
                               seqlen, heads, embed, ctx.scale, False,
                               ctx.stream, ctx.sync)
        ext_ops.FastBmm1Dgrad1(batch1.flatten().contiguous(),
                               flat_grad_output, grad_batch2.view(-1), batch,
                               seqlen, heads, embed, ctx.scale, False,
                               ctx.stream, ctx.sync)

        # One entry per forward input; only the two operands get gradients.
        return (grad_batch1, grad_batch2) + (None,) * 8


class Bmm1(torch.nn.Module):
    """Module front end for ``Bmm1Function``.

    Holds the call-invariant configuration and passes it, together with the
    per-call tensors, to ``Bmm1Function.apply``.

    Args:
        batch: Accepted but neither stored nor used by the constructor.
        seqlen: Stored as ``maxseqlen``.
        heads: Stored as ``heads``.
        embed: Stored as ``embed``.
        scale: Stored and forwarded to the function (default ``False``).
        stream: Stored and forwarded to the function (default ``True``).
        sync: Stored and forwarded to the function (default ``True``).
    """

    def __init__(self,
                 batch,
                 seqlen,
                 heads,
                 embed,
                 scale=False,
                 stream=True,
                 sync=True):
        super().__init__()

        # Flags handed through to the extension kernels on every call.
        self.scale, self.stream, self.sync = scale, stream, sync

        # Static problem dimensions.
        self.maxseqlen = seqlen
        self.heads, self.embed = heads, embed

    def forward(self, batch1, batch2, batch, seqlen):
        """Apply ``Bmm1Function`` to ``batch1`` / ``batch2``.

        Args:
            batch1: First operand tensor.
            batch2: Second operand tensor.
            batch: Number of sequences in this call.
            seqlen: Tensor of per-sequence lengths.

        Returns:
            Whatever ``Bmm1Function.apply`` returns for these arguments.
        """
        static_args = (self.maxseqlen, self.heads, self.embed, self.scale,
                       self.stream, self.sync)
        # The function expects ``seqlen`` ahead of ``batch``.
        return Bmm1Function.apply(batch1, batch2, seqlen, batch, *static_args)


##########################################################################################


class Bmm1StridedFunction(torch.autograd.Function):
    """Autograd binding for the ``ext_ops`` FastBmm1 kernels on a single tensor.

    The same ``mixed`` tensor is passed to the kernels as both operands and
    the boolean following ``scale`` is always ``True``.
    NOTE(review): presumably ``mixed`` is a packed QKV activation of shape
    ``(ntokens, heads * 3 * embed)`` and ``True`` selects the strided layout
    -- confirm against the ``ext_ops`` sources.

    ``forward`` also returns ``mixed`` itself as a second output so that
    ``backward`` is handed the gradient buffer belonging to ``mixed``; the
    backward kernels receive that buffer as their output argument.
    """

    @staticmethod
    def forward(ctx, mixed, seqlen, batch, maxseqlen, heads, embed, scale,
                stream, sync, timers):
        """Run the forward kernel and stash everything backward needs.

        Args:
            ctx: Autograd context.
            mixed: Input tensor used for both kernel operands.
            seqlen: Integer tensor of per-sequence lengths.
            batch: Number of sequences; only the first ``batch`` entries of
                ``seqlen`` contribute to the output size.
            maxseqlen: Stored on ``ctx``; not otherwise used here.
            heads: Head count; multiplies the output size.
            embed: Forwarded to the kernel (assumed per-head width -- TODO
                confirm).
            scale, stream, sync: Forwarded unchanged to the kernel.
            timers: Falsy to disable timing, otherwise a mapping providing
                ``'start_fprop'`` / ``'stop_fprop'`` (and the dgrad / wgrad
                equivalents used in ``backward``) objects with ``record()``.

        Returns:
            ``(output, mixed)`` where ``output`` is a 1-D float16 tensor with
            ``heads * sum(seqlen[i] ** 2)`` elements.
        """
        ctx.save_for_backward(mixed, seqlen)
        ctx.batch = batch
        ctx.maxseqlen = maxseqlen
        ctx.heads = heads
        ctx.embed = embed
        ctx.scale = scale
        ctx.sync = sync
        ctx.stream = stream
        ctx.timers = timers
        ctx.ntokens = seqlen.sum().item()

        # One vectorised reduction instead of indexing the tensor element by
        # element; int64 keeps the squared sum from wrapping in a narrow dtype
        # and ``.item()`` yields a plain Python int for the allocation below.
        lengths = seqlen[:batch].to(torch.int64)
        ntokens2 = (lengths * lengths).sum().item()

        # Allocate next to the input rather than on whatever the current
        # CUDA device happens to be.
        output = torch.empty(ntokens2 * heads,
                             device=mixed.device,
                             dtype=torch.float16)

        if timers:
            timers['start_fprop'].record()
        ext_ops.FastBmm1Fprop(mixed, mixed, output, batch, seqlen, heads,
                              embed, scale, True, stream, sync)
        if timers:
            timers['stop_fprop'].record()

        return output, mixed

    @staticmethod
    def backward(ctx, grad_output, grad_mixed):
        """Run the two backward kernels with ``grad_mixed`` as their output.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient with respect to the first forward output.
            grad_mixed: Gradient with respect to the second forward output
                (``mixed``).  It is passed to both kernels as the output
                argument; presumably they write into it in place -- confirm
                against the ``ext_ops`` sources.

        Returns:
            A 10-tuple matching the forward inputs: the first ``ntokens`` rows
            of ``grad_mixed`` as the gradient for ``mixed`` followed by
            ``None`` for the nine non-differentiable arguments (``seqlen`` is
            an integer tensor and the rest are plain Python values).
        """
        mixed, seqlen = ctx.saved_tensors
        batch = ctx.batch
        heads = ctx.heads
        embed = ctx.embed
        timers = ctx.timers

        if timers:
            timers['start_dgrad'].record()
        ext_ops.FastBmm1Dgrad2(mixed, grad_output, grad_mixed, batch, seqlen,
                               heads, embed, ctx.scale, True, ctx.stream,
                               ctx.sync)
        if timers:
            timers['stop_dgrad'].record()

        if timers:
            timers['start_wgrad'].record()
        ext_ops.FastBmm1Dgrad1(mixed, grad_output, grad_mixed, batch, seqlen,
                               heads, embed, ctx.scale, True, ctx.stream,
                               ctx.sync)
        if timers:
            timers['stop_wgrad'].record()

        # Exactly one entry per forward input.  Previously a tensor was also
        # returned in the ``seqlen`` slot together with one surplus ``None``;
        # neither had any effect on the computed gradients.
        return (grad_mixed[:ctx.ntokens],) + (None,) * 9


class Bmm1Strided(torch.nn.Module):
    """Module front end for ``Bmm1StridedFunction``.

    Holds the call-invariant configuration, optionally creates a set of CUDA
    timing events, and passes everything to ``Bmm1StridedFunction.apply``.

    Args:
        batch: Accepted but neither stored nor used by the constructor.
        seqlen: Stored as ``maxseqlen``.
        heads: Stored as ``heads``.
        embed: Stored as ``embed``.
        scale: Stored and forwarded to the function (default ``True``).
        stream: Stored and forwarded to the function (default ``True``).
        sync: Stored and forwarded to the function (default ``True``).
        timer: When truthy, ``self.timers`` becomes a dict of
            ``torch.cuda.Event(enable_timing=True)`` objects keyed by
            start/stop names for fprop, dgrad and wgrad; otherwise
            ``self.timers`` is ``None``.
    """

    def __init__(self,
                 batch,
                 seqlen,
                 heads,
                 embed,
                 scale=True,
                 stream=True,
                 sync=True,
                 timer=False):
        super().__init__()

        # Flags handed through to the extension kernels on every call.
        self.scale, self.stream, self.sync = scale, stream, sync

        # Static problem dimensions.
        self.maxseqlen = seqlen
        self.heads, self.embed = heads, embed

        # One timing event per start/stop marker, created only on request.
        event_names = ('start_fprop', 'start_dgrad', 'start_wgrad',
                       'stop_fprop', 'stop_dgrad', 'stop_wgrad')
        self.timers = ({
            name: torch.cuda.Event(enable_timing=True)
            for name in event_names
        } if timer else None)

    def forward(self, mixed, batch, seqlen):
        """Apply ``Bmm1StridedFunction`` to ``mixed``.

        Args:
            mixed: Input tensor handed to the function.
            batch: Number of sequences in this call.
            seqlen: Tensor of per-sequence lengths.

        Returns:
            Whatever ``Bmm1StridedFunction.apply`` returns for these
            arguments.
        """
        static_args = (self.maxseqlen, self.heads, self.embed, self.scale,
                       self.stream, self.sync, self.timers)
        # The function expects ``seqlen`` ahead of ``batch``.
        return Bmm1StridedFunction.apply(mixed, seqlen, batch, *static_args)


###########################################################################################
